
Bayesian Neural Network¤

We borrow this tutorial from the official Turing Docs. We will show how the explicit parameterization of Lux enables first-class composability with packages that expect flattened parameter vectors.
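As a rough illustration of what a flattened parameter vector means, the sketch below (not part of this tutorial; it assumes ComponentArrays.jl, which we mention again later) converts Lux's NamedTuple parameters into a single flat vector and feeds that vector straight back to the model.

# Minimal sketch (assumes ComponentArrays.jl): flatten Lux parameters into a vector
using Lux, Random, ComponentArrays

demo_rng = Random.default_rng()
demo_model = Dense(2, 3, tanh)
demo_ps, demo_st = Lux.setup(demo_rng, demo_model)  # demo_ps is a NamedTuple of arrays

ps_flat = ComponentArray(demo_ps)                   # behaves like a flat Vector of length 9
x_flat_demo = randn(demo_rng, Float32, 2, 4)
y_flat_demo, _ = demo_model(x_flat_demo, ps_flat, demo_st)  # Lux accepts the flattened parameters directly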

We will use Turing.jl with Lux.jl to implement a classification algorithm. Let's start by importing the relevant libraries.

# Import libraries
using Lux
using Turing, Plots, Random, ReverseDiff, NNlib, Functors

# Hide sampling progress
Turing.setprogress!(false);

# Use reverse_diff due to the number of parameters in neural networks
Turing.setadbackend(:reversediff)
:reversediff

Generating data¤

Our goal here is to use a Bayesian neural network to classify points in an artificial dataset. The code below generates data points arranged in a box-like pattern and displays a graph of the dataset we'll be working with.

# Number of points to generate
N = 80
M = round(Int, N / 4)
rng = Random.default_rng()
Random.seed!(rng, 1234)

# Generate artificial data
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M]))

x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M]))

# Store all the data for later
xs = [xt1s; xt0s]
ts = [ones(2 * M); zeros(2 * M)]

# Plot data points
function plot_data()
    x1 = first.(xt1s)
    y1 = last.(xt1s)
    x2 = first.(xt0s)
    y2 = last.(xt0s)

    plt = Plots.scatter(x1, y1; color="red", clim=(0, 1))
    Plots.scatter!(plt, x2, y2; color="blue", clim=(0, 1))

    return plt
end

plot_data()

Building the Neural Network¤

The next step is to define a feedforward neural network where we express our parameters as distributions, not as single points as in traditional neural networks. For this we will use Dense to define linear layers and compose them via Chain; both are neural network primitives from Lux. The network nn we create will have two hidden layers with tanh activations and one output layer with sigmoid activation, as shown below.

nn is an instance that acts as a function: it takes data, parameters, and the current state as inputs and outputs predictions. We will define distributions over the neural network parameters.

# Construct a neural network using Lux
nn = Chain(Dense(2, 3, tanh), Dense(3, 2, tanh), Dense(2, 1, sigmoid))

# Initialize the model weights and state
ps, st = Lux.setup(rng, nn)

Lux.parameterlength(nn) # number of parameters in the NN
20
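As a quick sanity check (a hypothetical snippet, not part of the original tutorial), we can run one forward pass to see the call signature: the model takes the data, the parameters, and the state, and returns the predictions together with an updated state.

# Hypothetical forward pass illustrating the (data, parameters, state) signature
x_demo = hcat(xs...)            # 2 × 80 matrix of inputs
y_demo, _ = nn(x_demo, ps, st)  # y_demo is a 1 × 80 matrix of probabilities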

The probabilistic model specification below creates a parameters variable, a vector of IID normal random variables. It represents all the parameters of our neural net (weights and biases).

# Create a regularization term and a Gaussian prior variance term.
alpha = 0.09
sig = sqrt(1.0 / alpha)
3.3333333333333335
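In other words, since sig is used below as the per-parameter standard deviation of the prior, each network parameter is assigned an independent zero-mean Gaussian prior:

\[ \theta_j \overset{\text{iid}}{\sim} \mathcal{N}(0, \sigma^2), \qquad \sigma = \sqrt{1 / \alpha} \approx 3.33 \]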

The following function constructs a named tuple from a sampled parameter vector. We could also use ComponentArrays here and simply broadcast to avoid this step, but let's do it this way to avoid extra dependencies.

function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
    @assert length(ps_new) == Lux.parameterlength(ps)
    i = 1
    function get_ps(x)
        z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
        i += length(x)
        return z
    end
    return fmap(get_ps, ps)
end
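As a quick (hypothetical) check of this helper, we can rebuild the parameters from a random flat vector and confirm that the layer structure and array sizes are preserved; the layer_1 and weight field names below follow Lux's naming for the Chain defined above.

# Hypothetical check: rebuilding parameters from a flat vector keeps the structure
v = randn(rng, Float32, Lux.parameterlength(nn))
ps_rebuilt = vector_to_parameters(v, ps)

size(ps_rebuilt.layer_1.weight) == size(ps.layer_1.weight)   # true
Lux.parameterlength(ps_rebuilt) == Lux.parameterlength(ps)   # true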

# Specify the probabilistic model.
@model function bayes_nn(xs, ts)
    global st

    # Sample the parameters
    nparameters = Lux.parameterlength(nn)
    parameters ~ MvNormal(zeros(nparameters), sig .* ones(nparameters))

    # Forward NN to make predictions
    preds, st = nn(xs, vector_to_parameters(parameters, ps), st)

    # Observe each prediction.
    for i in 1:length(ts)
        ts[i] ~ Bernoulli(preds[i])
    end
end
bayes_nn (generic function with 2 methods)

Inference can now be performed by calling sample. We use the HMC sampler here.

# Perform inference.
N = 5000
ch = sample(bayes_nn(hcat(xs...), ts), HMC(0.05, 4), N)
Chains MCMC chain (5000×29×1 Array{Float64, 3}):

Iterations        = 1:1:5000
Number of chains  = 1
Samples per chain = 5000
Wall duration     = 90.47 seconds
Compute duration  = 90.47 seconds
parameters        = parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], parameters[8], parameters[9], parameters[10], parameters[11], parameters[12], parameters[13], parameters[14], parameters[15], parameters[16], parameters[17], parameters[18], parameters[19], parameters[20]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, step_size, nom_step_size

Summary Statistics
      parameters      mean       std   naive_se      mcse       ess      rhat  ⋯
          Symbol   Float64   Float64    Float64   Float64   Float64   Float64  ⋯

   parameters[1]    1.4675    2.3083     0.0326    0.2693   11.7281    2.0767  ⋯
   parameters[2]    5.5552    2.7594     0.0390    0.3231   12.4936    1.1899  ⋯
   parameters[3]    0.0273    0.7531     0.0107    0.0774   18.6028    1.1024  ⋯
   parameters[4]   -1.8129    1.6025     0.0227    0.1796   15.6117    1.0448  ⋯
   parameters[5]    0.7135    1.3562     0.0192    0.1557   13.5410    1.1842  ⋯
   parameters[6]    4.7089    1.7599     0.0249    0.1976   18.8688    1.0080  ⋯
   parameters[7]   -5.0105    3.4342     0.0486    0.4005   12.7796    1.0072  ⋯
   parameters[8]    0.1951    2.3736     0.0336    0.2763   12.0470    1.4253  ⋯
   parameters[9]    0.8814    1.6137     0.0228    0.1798   19.0521    1.0278  ⋯
  parameters[10]   -0.9196    4.1529     0.0587    0.4896   10.7452    2.1730  ⋯
  parameters[11]   -0.0995    2.6126     0.0369    0.3022   12.2717    1.2131  ⋯
  parameters[12]   -2.0922    3.0141     0.0426    0.3552   11.4220    1.6551  ⋯
  parameters[13]    4.3457    1.7685     0.0250    0.2022   14.9620    1.0385  ⋯
  parameters[14]   -2.9048    1.6752     0.0237    0.1877   15.8258    1.0980  ⋯
  parameters[15]    2.1551    1.8145     0.0257    0.2082   13.7275    1.2178  ⋯
  parameters[16]   -3.2932    1.4081     0.0199    0.1561   19.8148    1.0165  ⋯
  parameters[17]   -3.4255    2.6896     0.0380    0.3158   12.6164    1.2372  ⋯
        ⋮             ⋮         ⋮         ⋮          ⋮         ⋮         ⋮     ⋱
                                                     1 column and 3 rows omitted

Quantiles
      parameters       2.5%     25.0%     50.0%     75.0%     97.5%
          Symbol    Float64   Float64   Float64   Float64   Float64

   parameters[1]    -2.8737   -0.2060    1.2397    3.2028    6.3763
   parameters[2]     0.7006    3.3812    5.8509    7.5387   11.0436
   parameters[3]    -2.0910   -0.2111    0.1039    0.4571    1.0501
   parameters[4]    -5.8984   -2.5130   -1.6280   -0.7613    0.8443
   parameters[5]    -0.7348   -0.0564    0.4240    0.9659    5.2736
   parameters[6]     1.6909    3.4942    4.5379    5.9346    8.2718
   parameters[7]   -11.1144   -7.4680   -5.2027   -2.8930    1.9571
   parameters[8]    -4.3334   -1.5037    0.2996    2.0048    4.4496
   parameters[9]    -1.9076   -0.1331    0.6428    1.7236    4.4656
  parameters[10]    -8.4372   -4.2836    0.1960    2.2765    5.9196
  parameters[11]    -5.1766   -1.7353    0.2017    1.6669    4.5451
  parameters[12]    -6.9702   -4.6678   -2.4274    0.6276    3.0518
  parameters[13]     1.2283    3.0997    4.1223    5.5624    8.0027
  parameters[14]    -6.2349   -4.0604   -2.9481   -1.8547    0.5993
  parameters[15]    -1.9406    1.1370    2.4385    3.4656    5.0130
  parameters[16]    -5.6563   -4.2682   -3.5360   -2.4030    0.0173
  parameters[17]    -9.2556   -4.9700   -3.1804   -1.7896    1.5551
        ⋮             ⋮          ⋮         ⋮         ⋮         ⋮
                                                       3 rows omitted

Now we extract the parameter samples from the sampled chain as theta (this is of size 5000 x 20 where 5000 is the number of iterations and 20 is the number of parameters). We'll use these primarily to determine how good our model's classifier is.

# Extract all weight and bias parameters.
theta = MCMCChains.group(ch, :parameters).value;
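A brief (hypothetical) shape check, assuming the usual MCMCChains layout of iterations × parameters × chains for the underlying value array:

size(theta)   # (5000, 20, 1): iterations × parameters × chains
theta[1, :]   # the 20 network parameters from the first iteration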

Prediction Visualization¤

# A helper to run the nn through data `x` using parameters `theta`
nn_forward(x, theta) = nn(x, vector_to_parameters(theta, ps), st)[1]

# Plot the data we have.
plot_data()

# Find the index that provided the highest log posterior in the chain.
_, i = findmax(ch[:lp])

# Extract the max row value from i.
i = i.I[1]

# Plot the posterior distribution with a contour plot
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_forward([x1, x2], theta[i, :])[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z)

The contour plot above shows that the MAP method is not too bad at classifying our data. Now we can visualize our predictions.

\[ p(\tilde{x} \mid X, \alpha) = \int_{\theta} p(\tilde{x} \mid \theta) \, p(\theta \mid X, \alpha) \, d\theta \approx \frac{1}{N} \sum_{i=1}^{N} f_{\theta^{(i)}}(\tilde{x}), \qquad \theta^{(i)} \sim p(\theta \mid X, \alpha) \]

The nn_predict function takes the average predicted value from a network parameterized by weights drawn from the MCMC chain.

# Return the average predicted value across multiple weights.
function nn_predict(x, theta, num)
    return mean([nn_forward(x, view(theta, i, :))[1] for i in 1:10:num])
end
nn_predict (generic function with 1 method)

Next, we use the nn_predict function to predict the value at a grid of points where the x1 and x2 coordinates range between -6 and 6. As we can see below, we still have a satisfactory fit to our data, and more importantly, we can now see much more easily where the neural network is uncertain about its predictions: the regions between the cluster boundaries.

Plot the average prediction.

plot_data()

n_end = 1500
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_predict([x1, x2], theta, n_end)[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z)


If we are interested in how the predictive power of our Bayesian neural network evolved between samples, the following animation shows the contour plot generated from the network weights in samples 1 to 1,000.

# Number of iterations to plot.
n_end = 1000

anim = @gif for i in 1:n_end
    plot_data()
    Z = [nn_forward([x1, x2], theta[i, :])[1] for x1 in x1_range, x2 in x2_range]
    contour!(x1_range, x2_range, Z; title="Iteration $i", clim=(0, 1))
end every 5


This page was generated using Literate.jl.